package com.offcn.bigdata.spark.sql.p2

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

/**
  * SparkSQL function operations: registers a custom UDAF
  * (user-defined aggregate function) and compares it with the
  * built-in avg on a per-province height average.
  */
object _02SparkSQLWithFunctioOps {
    def main(args: Array[String]): Unit = {
        // Local-mode session for the demo.
        val spark = SparkSession.builder()
            .appName("_02SparkSQLWithFunctioOps")
            .master("local[*]")
            .getOrCreate()

        // Make the custom average available to SQL as "myAvg".
        spark.udf.register("myAvg", new MyAvgUDAF())

        // Load the people dataset and expose it as a temp view.
        val peopleDF = spark.read.json("file:/E:/data/spark/sql/people.json")
        peopleDF.createOrReplaceTempView("person")

        // Built-in avg and the custom myAvg should agree per province.
        val query =
            """
              |select
              |   province,
              |   round(avg(height), 2) avg_height,
              |   round(myAvg(height), 2) my_avg_height
              |from person
              |group by province
            """.stripMargin
        spark.sql(query).show()

        spark.stop()
    }
}

/**
  * User-defined aggregate function computing the arithmetic mean of a
  * Double column (mirrors the built-in `avg`, including null-skipping).
  *
  * Buffer layout: (sum: Double, count: Long).
  * NOTE(review): UserDefinedAggregateFunction is deprecated since
  * Spark 3.0 — consider migrating to org.apache.spark.sql.expressions.Aggregator.
  */
class MyAvgUDAF extends UserDefinedAggregateFunction {

    /** Schema of the UDAF's input argument list: a single Double column. */
    override def inputSchema: StructType = {
        StructType(Array(
            StructField("height", DataTypes.DoubleType, false)
        ))
    }

    /**
      * Schema of the intermediate aggregation buffer:
      *   sum   — running total of observed values
      *   count — number of non-null rows aggregated so far
      * count is LongType (was IntegerType) to avoid overflow on large inputs,
      * matching how the built-in avg tracks its row count.
      */
    override def bufferSchema: StructType = {
        StructType(Array(
            StructField("sum", DataTypes.DoubleType, false),
            StructField("count", DataTypes.LongType, false)
        ))
    }

    /** Result type of the aggregation. */
    override def dataType: DataType = DataTypes.DoubleType

    /** Same input rows always produce the same result. */
    override def deterministic: Boolean = true

    /** Final result: sum / count; NaN when no rows were aggregated
      * (previously an unguarded division). */
    override def evaluate(buffer: Row): Double = {
        val sum = buffer.getDouble(0)
        val count = buffer.getLong(1)
        if (count == 0L) Double.NaN else sum / count
    }

    /** Zero out the buffer: sum = 0.0, count = 0. */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0, 0.0d)
        buffer.update(1, 0L)
    }

    /**
      * Partition-local aggregation: fold one input row into the buffer.
      * Null inputs are skipped (JSON rows may lack the column at runtime,
      * and SQL avg semantics ignore nulls); the original crashed on null.
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
            buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
            buffer.update(1, buffer.getLong(1) + 1L)
        }
    }

    /**
      * Merge two partial buffers. buffer2 has the bufferSchema layout
      * (sum, count) — NOT the input schema.
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
        buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }
}

